#!/usr/bin/env python3

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""
    (NOTE: To use this class, please follow the tutorial here:
    http://parl.ai/static/docs/tutorial_worlds.html#multiprocessed-pytorch-dataloader)

"""
from .teachers import FixedDialogTeacher
from parlai.scripts.build_pytorch_data import build_data
from .agents import get_agent_module
import json
import math
import random
from functools import wraps
import importlib
from functools import lru_cache
try:
    import torch  # noqa: F401
except Exception as e:
    raise ImportError('Need to install Pytorch: go to pytorch.org')
from torch.utils.data import ConcatDataset, Dataset, DataLoader, sampler
from torch.multiprocessing import Lock, Value
import ctypes
from threading import Thread, Condition, RLock


'''
    Maps episode length to dictionary with following keys:
        current_idx: which episode in the list are we at (if simply indexing
            into list)
        ep_list: list of episodes of the length of the key
        bucket_complete: if there are no more episodes left to consider in
            the bucket
'''
# Shared state used by the `batch_cache` decorator below. The LoaderProcess
# thread fills it and the teacher drains it.
#
# Maps a bucket key (the number of spaces in an episode's 'text', see
# `put_in_cache`) to a dict with keys 'current_idx', 'ep_list' and
# 'bucket_complete'. `consolidate` may also add a bucket under key -1 for
# leftover episodes.
length_to_eps = {}
# Batches already handed out or consolidated; only used when the cache type
# is 'pop', where later batches are drawn from this list at random.
batches = []
# Set to True by `consolidate` once the last episode has been put in the cache
load_complete = Value(ctypes.c_bool, False)
# Lock to access batches
batches_lock = Lock()
# Lock to access length_to_eps
cache_lock = Lock()
# Single re-entrant lock shared by both condition variables below, so holding
# either condition also permits notifying/waiting on the other.
fill_cache_lock = RLock()
# Condition the teacher notifies to ask the Loader to add more to the cache
add_to_cache_cv = Condition(lock=fill_cache_lock)
# Condition the Loader notifies to tell the teacher the cache has episodes
cache_filled_cv = Condition(lock=fill_cache_lock)


def batch_cache(function):
    """Decorator routing batch loading through a shared, length-bucketed cache.

    The first positional argument of the wrapped method (its ``self``)
    decides what the wrapper does:

    * If ``caller.batch_cache_type == 'none'`` or ``caller.datatype`` does not
      start with ``'train'``, the wrapped function is called unchanged.
    * If the caller is a ``LoaderProcess``, the wrapped function is called to
      fetch ``(idx, batch)`` and every ``(ep_idx, episode)`` in the batch is
      filed into ``length_to_eps``. ``None`` from the wrapped function is
      passed through.
    * Otherwise the caller is treated as the teacher: the wrapped function is
      NOT called, and ``(teacher.batch_idx + 1, batch)`` is returned with the
      batch drawn from the cache.

    The caller must expose ``bsz``, ``batch_cache_type`` and ``datatype``;
    loaders also need ``dataset`` and ``batch_length_range``, teachers also
    need ``num_batches`` and ``batch_idx``.
    """
    max_cache_size = 10000  # Max unseen eps
    min_cache_size = 1000  # Min unseen eps

    def get_cache_size():
        """Return the number of cached episodes not yet handed out."""
        return sum(
            len(v['ep_list']) - v['current_idx'] for v in length_to_eps.values()
        )

    def get_available_buckets(bsz):
        """Return the buckets a batch can currently be drawn from.

        While loading is still in progress only buckets holding at least a
        full batch qualify. Once loading is complete, any bucket that is not
        marked complete or that still has unseen episodes qualifies.
        """
        if load_complete.value:
            return {
                k: v
                for k, v in length_to_eps.items()
                if not v['bucket_complete'] or len(v['ep_list']) - v['current_idx'] > 0
            }
        else:
            return {
                k: v
                for k, v in length_to_eps.items()
                if len(v['ep_list']) - v['current_idx'] >= bsz
            }

    def reset():
        """Rewind every bucket so all cached episodes count as unseen again."""
        with cache_lock:
            for idx in length_to_eps:
                length_to_eps[idx]['current_idx'] = 0
                length_to_eps[idx]['bucket_complete'] = False

    def consolidate(caller):
        """Regroup the remaining episodes into batches once loading is done.

        Marks loading as complete, then:

        * 'index' mode: walks buckets in ascending key order, carrying
          leftovers forward so that each bucket's unseen part is made of
          whole batches; a final remainder goes into a bucket keyed ``-1``.
        * 'pop' mode: concatenates every bucket's episodes in ascending key
          order and appends them to ``batches`` in chunks of ``bsz`` (the
          last chunk may be shorter).
        """
        load_complete.value = True
        bsz = caller.bsz
        batch = []
        sorted_lengths = sorted(length_to_eps.keys())
        with cache_lock:
            if caller.batch_cache_type == 'index':
                for length in sorted_lengths:
                    current_idx = length_to_eps[length]['current_idx']
                    ep_list = length_to_eps[length]['ep_list']
                    unseen_eps = ep_list[current_idx:]
                    # Keep only what was already served; unseen episodes are
                    # re-added below in batch-sized groups.
                    length_to_eps[length]['ep_list'] = ep_list[:current_idx]
                    batch = unseen_eps + batch
                    while len(batch) >= bsz:
                        length_to_eps[length]['ep_list'] += batch[:bsz]
                        batch = batch[bsz:]
                if len(batch) > 0:
                    length_to_eps[-1] = {
                        'current_idx': 0,
                        'ep_list': batch,
                        'bucket_complete': False
                    }
            elif caller.batch_cache_type == 'pop':
                for length in sorted_lengths:
                    batch += length_to_eps[length]['ep_list']
                with batches_lock:
                    while len(batch) >= bsz:
                        batches.append(batch[:bsz])
                        batch = batch[bsz:]
                if len(batch) > 0:
                    with batches_lock:
                        batches.append(batch)

    def flatten(l):
        """Flatten a list of lists by one level."""
        return [item for sublist in l for item in sublist]

    def put_in_cache(ep_idx, episode, caller):
        """File episode ``ep_idx`` into the bucket matching its text length.

        The length is the number of spaces in ``episode['text']``. Existing
        buckets are probed starting with the exact length and then
        alternating ``length + i`` / ``length - i`` for
        ``i < caller.batch_length_range``, each clamped to at least 1. If
        none exists a new bucket is created. Loading the last episode of the
        dataset triggers ``consolidate`` and wakes waiting teachers.
        """
        length = episode['text'].count(' ')
        lengths = [length] + flatten([
            [length + i, length + (i * -1)]
            for i in range(1, caller.batch_length_range)
        ])
        lengths = [max(i, 1) for i in lengths]
        in_cache = False
        for l in lengths:
            if l in length_to_eps:
                with cache_lock:
                    length_to_eps[l]['ep_list'] += [(ep_idx, episode)]
                in_cache = True
                break
        if not in_cache:
            # Create the bucket under the clamped key that the probe above
            # uses. Using the raw length would create bucket 0 for a
            # zero-space text, which is never probed and so would be
            # overwritten (dropping its episodes) by the next such episode.
            with cache_lock:
                length_to_eps[lengths[0]] = {
                    'current_idx': 0,
                    'ep_list': [(ep_idx, episode)],
                    'bucket_complete': False
                }
        if ep_idx == caller.dataset.num_episodes() - 1:
            consolidate(caller)
            # Both conditions share fill_cache_lock, so holding one is enough
            # to notify the other.
            with add_to_cache_cv:
                cache_filled_cv.notify_all()

    @wraps(function)
    def wrapper(*args):
        caller = args[0]
        batch_cache_type = caller.batch_cache_type
        bsz = caller.bsz
        if batch_cache_type == 'none' or not caller.datatype.startswith('train'):
            return function(*args)
        # If Loader, put episodes in cache
        if isinstance(caller, LoaderProcess):
            with add_to_cache_cv:
                # Back off while the cache is full and the teacher has
                # something to draw from.
                while (get_cache_size() >= max_cache_size and
                        len(get_available_buckets(bsz)) > 0):
                    cache_filled_cv.notify_all()
                    add_to_cache_cv.wait()
            idx_and_batch = function(*args)
            if idx_and_batch is None:
                return None
            for ep_index, ep in idx_and_batch[1]:
                put_in_cache(ep_index, ep, caller)
            return idx_and_batch
        # If teacher, return batch of episodes
        else:
            teacher = caller
            num_batches = teacher.num_batches
            while True:
                with cache_filled_cv:
                    # Until loading completes, wait for the loader whenever
                    # the cache is low or no bucket can fill a batch.
                    while (not load_complete.value and
                            (get_cache_size() <= min_cache_size or
                                len(get_available_buckets(bsz)) == 0)):
                        add_to_cache_cv.notify()
                        cache_filled_cv.wait()
                if load_complete.value and batch_cache_type == 'pop':
                    return teacher.batch_idx + 1, random.choice(batches)
                batch = None
                available_buckets = get_available_buckets(bsz)
                if len(available_buckets) != 0:
                    # Pick length index at random
                    length = random.choice(list(available_buckets.keys()))
                    with cache_lock:
                        current_idx = length_to_eps[length]['current_idx']
                        ep_list = length_to_eps[length]['ep_list']
                        num_eps = len(ep_list)
                        if num_eps - current_idx >= bsz:
                            if batch_cache_type == 'pop':
                                batch = ep_list[:bsz]
                                length_to_eps[length]['ep_list'] = ep_list[bsz:]
                            else:
                                batch = ep_list[current_idx: current_idx + bsz]
                                length_to_eps[length]['current_idx'] = (
                                    current_idx + bsz
                                )
                        elif load_complete.value and num_eps > 0:
                            if batch_cache_type == 'pop':
                                batch = ep_list
                            elif num_eps - current_idx > 0:
                                batch = ep_list[current_idx:]
                                # Everything in this bucket is now consumed.
                                # Pointing one short of the end would leave
                                # the last episode counted as unseen and let
                                # it be served again.
                                length_to_eps[length]['current_idx'] = num_eps
                            length_to_eps[length]['bucket_complete'] = True

                if batch is not None:
                    if batch_cache_type == 'pop':
                        with batches_lock:
                            batches.append(batch)
                    elif teacher.batch_idx + 1 >= num_batches:
                        # End of the epoch in 'index' mode: rewind buckets
                        reset()
                    return teacher.batch_idx + 1, batch

    return wrapper


# Get Datasets from the options
def get_dataset_classes(opt):
    """Resolve the datasets requested in ``opt``.

    Returns a list of ``(dataset_class, collate_fn, task_name)`` tuples.

    Every comma-separated entry of ``opt['pytorch_teacher_task']`` yields the
    default dataset (``StreamDataset`` when ``'stream'`` is in the datatype,
    otherwise ``ParlAIDataset``) together with ``default_collate``.

    Every comma-separated entry of ``opt['pytorch_teacher_dataset']`` names a
    custom subclass of the pytorch Dataset class, e.g. on the command line:
    ``-pytd vqa_v1:VQADataset``

    If the class is called ``DefaultDataset`` the part after the colon can
    be left out, e.g.:
    ``-pytd vqa_v1``

    An ``internal:`` prefix looks the task up in ``parlai_internal`` instead
    of ``parlai``; a dotted name before the colon is used as the module path
    as-is. The collate function is taken from the dataset class if it has a
    ``collate`` attribute, else from the agent module of ``opt['model']`` if
    that has one, else ``default_collate``.
    """
    default_dataset = (
        StreamDataset if 'stream' in opt.get('datatype') else ParlAIDataset
    )
    dataset_name = opt.get('pytorch_teacher_dataset')
    task_name = opt.get('pytorch_teacher_task')

    datasets = []
    if task_name is not None:
        for task in task_name.split(','):
            datasets.append((default_dataset, default_collate, task))
    if not dataset_name:
        return datasets

    for spec in [piece.strip() for piece in dataset_name.split(',')]:
        full_task_name = spec
        if spec.startswith('internal:'):
            # To switch to local repo, useful for non-public projects
            # (make a directory called 'parlai_internal' with your private agents)
            repo = 'parlai_internal'
            spec = spec[9:]
        else:
            repo = 'parlai'

        parts = spec.split(':')
        head = parts[0]
        has_module_path = '.' in head
        if has_module_path:
            module_name = head
        else:
            module_name = '{}.tasks.{}.agents'.format(repo, head.lower())

        if len(parts) > 1:
            # Capitalise the first letter of the requested class name
            class_name = parts[1][0].upper() + parts[1][1:]
            if not has_module_path and 'Dataset' not in class_name:
                # Reformat from underscore to CamelCase and append "Dataset" to
                # class name by default if a complete path is not given.
                camel = ''.join(
                    word[0].upper() + word[1:] for word in class_name.split('_')
                )
                class_name = camel + 'Dataset'
        else:
            class_name = 'DefaultDataset'

        dataset_class = getattr(importlib.import_module(module_name), class_name)

        collate = default_collate
        model = opt.get('model', False)
        if hasattr(dataset_class, 'collate'):
            collate = dataset_class.collate
        elif model:
            collate = getattr(get_agent_module(model), 'collate', collate)
        datasets.append((dataset_class, collate, full_task_name))
    return datasets


class LoaderProcess(Thread):
    """A background daemon thread that pulls batches from a DataLoader
       so that the `batch_cache` decorator can load the examples into cache.

       Despite the name this is a ``threading.Thread``, not a separate
       process. It builds its own dataset(s) and a sequential, unshuffled
       DataLoader from ``opt``.
    """
    def __init__(self, opt):
        super().__init__(daemon=True)
        dataset_classes = get_dataset_classes(opt)
        if len(dataset_classes) > 1:
            datasets = []
            for class_name, collate_fn, task_name in dataset_classes:
                # Give each dataset its own copy of the options so the
                # caller's opt is not left pointing at the last task.
                dataset_opt = opt.copy()
                dataset_opt['pytorch_teacher_task'] = task_name
                dataset_opt['task'] = task_name
                datasets.append(class_name(dataset_opt))
                # With several datasets the last collate function wins
                self.collate = collate_fn
            self.dataset = ParlAIConcatDataset(datasets)
        else:
            class_name, self.collate, task_name = dataset_classes[0]
            self.dataset = class_name(opt)
        self.bsz = opt.get('batchsize', 1)
        # NOTE(review): PytorchDataTeacher reads opt['numworkers'], whereas
        # this reads 'num_workers' with a default of 4 -- confirm which key
        # is intended.
        self.num_workers = opt.get('num_workers', 4)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.bsz,
            shuffle=False,
            sampler=sampler.SequentialSampler(self.dataset),
            num_workers=self.num_workers,
            collate_fn=self.collate,
            pin_memory=False,
            drop_last=False,
        )
        self.datatype = opt.get('datatype')
        self.data = enumerate(self.dataloader)
        self.batch_cache_type = opt.get('batch_sort_cache')
        self.batch_length_range = opt.get('batch_length_range')

    def run(self):
        """Keep loading batches until `load_next` returns None."""
        while True:
            idx_and_batch = self.load_next()
            if idx_and_batch is None:
                return

    @batch_cache
    def load_next(self):
        """Return the next ``(idx, batch)`` from the dataloader, or None
        once it is exhausted."""
        try:
            return next(self.data)
        except StopIteration:
            return None


# Default collate function (for how to prepare a batch)
def default_collate(batch):
    """Turn a batch of dataset items into a list of ``(idx, episode)`` pairs.

    Each item is indexed as ``item[0]`` (the index) and ``item[1]`` (the
    episode). When the episode is exactly a ``list`` only its first element
    is kept; any other value is passed through unchanged.
    """
    def first_if_list(payload):
        # Exact type check on purpose: only plain lists are unwrapped
        return payload[0] if type(payload) is list else payload

    return [(item[0], first_if_list(item[1])) for item in batch]


class StreamDataset(Dataset):
    """A Pytorch Dataset utilizing streaming.

    The data file is read line by line (one JSON example per line) instead
    of being loaded into memory. Episode counts come from the companion
    ``<datafile>.length`` JSON file.
    """
    def __init__(self, opt):
        self.opt = opt
        self.datatype = opt.get('datatype')
        self.datafile = build_data(self.opt)
        self.data_gen = self._data_generator(self.datafile)
        self.length_datafile = self.datafile + ".length"
        self.training = self.datatype.startswith('train')
        self._load_lens()

    def __getitem__(self, index):
        """Advance the stream until the episode tagged ``index`` comes up.

        The stream only moves forward (and wraps around at end of file), so
        this is meant for sequential access.
        """
        while True:
            idx, ep = next(self.data_gen)
            if idx == index:
                return (index, ep)

    def __len__(self):
        return self.num_episodes()

    def _load_lens(self):
        """Read ``num_eps`` and ``num_exs`` from the ``.length`` file."""
        with open(self.length_datafile) as length:
            lengths = json.load(length)
            self.num_eps = lengths['num_eps']
            self.num_exs = lengths['num_exs']

    def _data_generator(self, datafile):
        """Yield ``(idx, episode)`` from ``datafile`` forever, restarting
        from the top each time the file is exhausted."""
        while True:
            for idx, episode in self._read_episode(datafile):
                yield idx, episode

    def _read_episode(self, datafile):
        """Yield ``(idx, episode)`` for each episode in ``datafile``.

        ``episode`` is the list of examples accumulated up to and including
        one whose ``'episode_done'`` is truthy; ``idx`` is the zero-based
        line number of that closing example.
        """
        # Context manager so the handle is closed even when the generator is
        # abandoned or closed before reaching the end of the file.
        with open(datafile) as read:
            episode = []
            for idx, line in enumerate(read):
                example = json.loads(line)
                episode.append(example)
                if example['episode_done']:
                    yield idx, episode
                    episode = []

    def num_episodes(self):
        return self.num_eps

    def num_examples(self):
        return self.num_exs


class ParlAIDataset(Dataset):
    """A Pytorch Dataset, for random sampling.

    The whole data file (one JSON object per line) is held in memory so any
    index can be served directly. Counts are read from the companion
    ``<datafile>.length`` JSON file.
    """
    def __init__(self, opt):
        self.opt = opt
        self.datatype = opt.get('datatype')
        self.datafile = build_data(self.opt)
        self._setup_data()
        self.length_datafile = self.datafile + ".length"
        self.training = self.datatype.startswith('train')
        self._load_lens()

    def __getitem__(self, index):
        """Return ``(index, entry)`` for the ``index``-th loaded line."""
        entry = self.data[index]
        return index, entry

    def __len__(self):
        return self.num_episodes()

    def _load_lens(self):
        """Read ``num_eps`` and ``num_exs`` from the ``.length`` file."""
        with open(self.length_datafile) as handle:
            counts = json.load(handle)
        self.num_eps = counts['num_eps']
        self.num_exs = counts['num_exs']

    def _setup_data(self):
        """Parse every line of the data file into ``self.data``."""
        self.data = []
        with open(self.datafile) as handle:
            self.data.extend(json.loads(row) for row in handle)

    def num_episodes(self):
        return self.num_eps

    def num_examples(self):
        return self.num_exs


class ParlAIConcatDataset(ConcatDataset):
    """Override to set num_eps and num_exs.

    Totals are summed over the wrapped datasets on first use and then
    remembered on the instance. A per-instance cache is used instead of
    ``lru_cache`` on the methods, which would key on ``self``, keep the
    instance alive from a class-level cache, and (with ``maxsize=1``) be
    evicted whenever a second instance is queried.
    """

    def num_episodes(self):
        """Return the total number of episodes across all datasets."""
        total = self.__dict__.get('_num_episodes_total')
        if total is None:
            total = sum(d.num_episodes() for d in self.datasets)
            self._num_episodes_total = total
        return total

    def num_examples(self):
        """Return the total number of examples across all datasets."""
        total = self.__dict__.get('_num_examples_total')
        if total is None:
            total = sum(d.num_examples() for d in self.datasets)
            self._num_examples_total = total
        return total


class PytorchDataTeacher(FixedDialogTeacher):
    """Teacher that serves its data through a PyTorch ``DataLoader``.

    The dataset class(es) and collate function come from
    ``get_dataset_classes(opt)``. A non-shared instance builds the dataset
    and dataloader itself (sequential sampling when streaming or when
    ``opt['shuffle']`` is falsy, random sampling otherwise) and, unless
    ``opt['batch_sort_cache']`` is ``'none'``, starts a ``LoaderProcess``
    thread that feeds the batch cache. A shared instance reuses the dataset,
    dataloader and iterator of the instance it was shared from.
    """

    def __init__(self, opt, shared=None):
        opt['batch_sort'] = False
        super().__init__(opt, shared)
        self.use_batch_act = self.bsz > 1
        self.num_workers = opt['numworkers']
        self.batch_cache_type = opt.get('batch_sort_cache')
        # One can specify a collate function to use for preparing a batch
        self.opt = opt.copy()
        self.is_shared = shared is not None
        dataset_classes = self.get_dataset_class(opt)

        if not shared:
            streaming = False
            if len(dataset_classes) > 1:
                datasets = []
                for class_name, collate_fn, task_name in dataset_classes:
                    # Use a per-dataset copy: writing the task into `opt`
                    # itself would leave it set to the last task, and the
                    # LoaderProcess created from `opt` below would then
                    # resolve a different list of datasets.
                    dataset_opt = opt.copy()
                    dataset_opt['pytorch_teacher_task'] = task_name
                    dataset_opt['task'] = task_name
                    datasets.append(class_name(dataset_opt))
                    streaming = streaming or (class_name == StreamDataset)
                    # With several datasets the last collate function wins
                    self.collate_fn = collate_fn
                self.dataset = ParlAIConcatDataset(datasets)
            else:
                class_name, self.collate_fn, task_name = dataset_classes[0]
                streaming = class_name == StreamDataset
                self.dataset = class_name(opt)
            self.streaming = 'stream' in self.datatype or streaming
            if self.streaming or not opt.get('shuffle'):
                data_sampler = sampler.SequentialSampler(self.dataset)
                pin_memory = False
            else:
                data_sampler = sampler.RandomSampler(self.dataset)
                pin_memory = True

            self.pytorch_dataloader = DataLoader(
                self.dataset,
                batch_size=self.bsz,
                sampler=data_sampler,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                pin_memory=pin_memory,
                drop_last=False,
            )
            self.lastYs = [None] * self.bsz
            if self.batch_cache_type != 'none':
                self.loader_process = LoaderProcess(opt)
                self.loader_process.start()
            self.data = enumerate(self.pytorch_dataloader)
        else:
            # NOTE(review): collate_fn is not assigned on this path; confirm
            # that shared copies never reach next_example/next_batch or that
            # the attribute is provided some other way.
            self.dataset = shared['dataset']
            self.pytorch_dataloader = shared['pytorch_dataloader']
            self.lastYs = shared['lastYs']
            self.data = shared['data']

        self.num_batches = math.ceil(self.dataset.num_episodes() / self.bsz)
        self.reset()

    def get_dataset_class(self, opt):
        """Return the ``(dataset_class, collate_fn, task_name)`` tuples for
        ``opt``; a hook subclasses can override."""
        return get_dataset_classes(opt)

    def reset(self):
        """Reset the dialog so that it is at the start of the epoch,
        and all metrics are reset.
        """
        super().reset()
        self.reset_data()

    def reset_data(self):
        """Restart iteration and clear per-episode bookkeeping.

        Only a non-shared instance re-creates the dataloader iterator.
        """
        if not self.is_shared:
            self.data = enumerate(self.pytorch_dataloader)
        self.lastY = None
        self.epochDone = False
        self.episode = None
        self.episode_done = True
        self.episode_idx = 0
        self.batch_idx = 0

    def share(self):
        """Extend the base class's shared dict with the dataloader, the
        dataset and the current iterator."""
        shared = super().share()
        shared['pytorch_dataloader'] = self.pytorch_dataloader
        shared['dataset'] = self.dataset
        shared['data'] = self.data
        return shared

    def next_example(self):
        """Return ``(example, epoch_done)`` for the next example.

        After the epoch has finished, a non-training teacher keeps returning
        an empty ``episode_done`` example, while a training teacher restarts
        its data.
        """
        if self.epochDone:
            if not self.training:
                return {'episode_done': True, 'id': self.getID()}, True
            else:
                # Reset the data because it is streaming data
                self.reset_data()
        # Initialise up front so the mid-episode path below (which only
        # advances entry_idx) does not hit an unbound local.
        epoch_done = False
        if self.episode_done:
            try:
                self.episode_idx, self.episode = next(self.data)
                self.entry_idx = 0
            except StopIteration:
                ex = {'episode_done': True, 'id': self.getID()}
                epoch_done = True
        else:
            self.entry_idx += 1

        if not epoch_done:
            if self.collate_fn == default_collate:
                # default_collate yields (idx, episode) pairs; keep the
                # episode part only
                self.episode[self.entry_idx] = self.episode[self.entry_idx][1]
            ex = self.episode[self.entry_idx]
            self.episode_done = ex['episode_done']
            if (self.episode_done and
                    self.episode_idx + self.bsz >= self.num_episodes()):
                epoch_done = True
        return ex, epoch_done

    @batch_cache
    def get_next_batch(self):
        # employs a cache to see if there is a batch of equal size ready
        batch = next(self.data)
        return batch

    def next_batch(self):
        """Return the next batch and record in ``self.epochDone`` whether
        the epoch has ended.

        When the data is exhausted the batch is ``bsz`` references to one
        empty ``episode_done`` example.
        """
        if self.epochDone:
            if not self.training:
                return [{'episode_done': True, 'id': self.getID()}] * self.bsz
            else:
                # Reset the data because it is streaming data
                self.reset_data()
        try:
            self.batch_idx, batch = self.get_next_batch()
            if self.collate_fn == default_collate:
                # Drop the episode indices added by default_collate
                batch = [b[1] for b in batch]
            epoch_done = False
        except StopIteration:
            batch = [{'episode_done': True, 'id': self.getID()}] * self.bsz
            epoch_done = True
        if not epoch_done and self.batch_idx == self.num_batches:
            epoch_done = True
        self.epochDone = epoch_done
        return batch

    def num_episodes(self):
        """Get the number of episodes in this dataset."""
        return self.dataset.num_episodes()

    def num_examples(self):
        """Get the total number of examples in this dataset."""
        return self.dataset.num_examples()

    def act(self):
        """Send new dialog message."""
        action = super().act()
        self.lastY = action.get('labels', action.get('eval_labels', None))
        return action


class DefaultTeacher(PytorchDataTeacher):
    """Alias of ``PytorchDataTeacher`` under the name ``DefaultTeacher``;
    adds no behavior of its own."""
